import os, cv2, paddle
import numpy as np
import xml.etree.ElementTree as et
import multiprocessing as mp

class VOCDataset(paddle.io.Dataset):
    def __init__(self, lists_txt, label_txt, transforms):
        """
        Initialize a Pascal-VOC style detection dataset.

        params:
        - lists_txt : list file; every non-blank line is "<image path> <annotation path>",
                      both relative to the directory that contains lists_txt
        - label_txt : label file; the class name on line i (0-based) gets class id i
        - transforms: callable applied to each sample dict in __getitem__
        raises:
        - FileNotFoundError: lists_txt or label_txt does not exist
        """
        super().__init__()

        # Fail fast instead of returning a half-initialised object that would only
        # crash later (e.g. AttributeError on img_list inside __len__).
        if not os.path.exists(lists_txt):
            raise FileNotFoundError(f'错误：{lists_txt}不存在！')

        self.lists_txt = lists_txt  # list file
        self.label_txt = label_txt  # label file
        self.lists_dir = os.path.split(lists_txt)[0]  # paths in the list are relative to this
        self.cname2cid = self.get_cname2cid()  # class name -> class id

        # Read the (image, annotation) path pairs.
        self.img_list = []  # image paths
        self.ann_list = []  # annotation (XML) paths
        with open(self.lists_txt, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:  # tolerate blank lines such as a trailing empty line
                    continue
                img_path, ann_path = fields
                self.img_list.append(os.path.join(self.lists_dir, img_path))
                self.ann_list.append(os.path.join(self.lists_dir, ann_path))

        self.transforms = transforms  # per-sample augmentation

    def __getitem__(self, index):
        """
        Load and augment one sample.

        params:
        - index: sample index
        return:
        - data : the sample dict after self.transforms; before the transforms it holds
                 'image' (RGB uint8, H x W x 3), 'gtcls' (int32, N),
                 'gtbox' (float32, N x 4, pixel x1 y1 x2 y2) and
                 'imghw' (float32, [height, width])
        raises:
        - ValueError: the image file cannot be decoded
        """
        img_path = self.img_list[index]
        ann_path = self.ann_list[index]

        # Read the image bytes and decode them with OpenCV.
        with open(img_path, 'rb') as f:
            buffer = np.frombuffer(f.read(), dtype='uint8')
        image = cv2.imdecode(buffer, 1)  # flag 1: decode as a 3-channel (BGR) image
        if image is None:  # imdecode reports a decode failure by returning None
            raise ValueError(f'错误：{img_path}无法解码为图片！')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Parse the annotation; the decoded image size overrides the XML <size>.
        annotation = et.parse(ann_path)
        size = annotation.find('size')
        img_h = float(size.find('height').text)
        img_w = float(size.find('width').text)
        if img_h < 0 or img_w < 0:
            print(f'警告：{ann_path}文件：height:{img_h}或者width:{img_w}小于零')
        if img_h != float(image.shape[0]):
            print(f'警告：{ann_path}文件：height:{img_h}不等于图片实际height:{float(image.shape[0])}')
            img_h = float(image.shape[0])
        if img_w != float(image.shape[1]):
            print(f'警告：{ann_path}文件：width:{img_w}不等于图片实际width:{float(image.shape[1])}')
            img_w = float(image.shape[1])
        imghw = np.array([img_h, img_w], dtype='float32')

        objects = annotation.findall('object')
        gtcls = np.zeros((len(objects),), dtype='int32')
        gtbox = np.zeros((len(objects), 4), dtype='float32')
        for object_id, obj in enumerate(objects):
            cname = obj.find('name').text  # class name

            bndbox = obj.find('bndbox')
            x1 = float(bndbox.find('xmin').text)
            y1 = float(bndbox.find('ymin').text)
            x2 = float(bndbox.find('xmax').text)
            y2 = float(bndbox.find('ymax').text)

            # Clip to the image; x2/y2 are limited to the last pixel index.
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(x2, img_w - 1)
            y2 = min(y2, img_h - 1)

            if x2 > x1 and y2 > y1:
                gtcls[object_id] = self.cname2cid[cname]
                gtbox[object_id] = [x1, y1, x2, y2]
            else:
                # The row stays all-zero with class 0. NOTE(review): downstream this is
                # indistinguishable from PadClassBox padding — confirm the loss ignores it.
                print(f'警告：{ann_path}文件：无效的目标边框')

        data = {'image': image, 'gtcls': gtcls, 'gtbox': gtbox, 'imghw': imghw}
        return self.transforms(data)

    def __len__(self):
        """
        Return the number of samples.
        """
        return len(self.img_list)

    def get_cname2cid(self):
        """
        Build the class-name -> class-id mapping from self.label_txt.

        return:
        - cname2cid: dict mapping each stripped line to its 0-based line number
        raises:
        - FileNotFoundError: self.label_txt does not exist
        """
        # Previously a missing file returned None, which only failed later in __getitem__.
        if not os.path.exists(self.label_txt):
            raise FileNotFoundError(f'错误：{self.label_txt}不存在！')

        cname2cid = {}
        with open(self.label_txt, 'r') as f:
            # Every line consumes an id (blank lines included), so ids equal line numbers.
            for cid, cname in enumerate(f):
                cname2cid[cname.strip()] = cid

        return cname2cid

class Compose():
    def __init__(self, transforms):
        """
        Hold an ordered pipeline of per-sample transforms.

        params:
        - transforms: sequence of callables; each takes a data dict and returns one
        """
        self.transforms = transforms

    def __call__(self, data):
        """
        Run the pipeline, feeding each transform the previous one's output.

        params:
        - data  : sample dict to process
        return:
        - result: the dict returned by the last transform (data itself if empty)
        """
        result = data
        for transform in self.transforms:
            result = transform(result)
        return result

class RandomDistort():
    def __init__(self,
                 hue=(-18, 18, 0.5),
                 saturation=(0.5, 1.5, 0.5),
                 contrast=(0.5, 1.5, 0.5),
                 brightness=(0.5, 1.5, 0.5)):
        """
        Configure the photometric distortions.

        Each argument is a (low, high, prob) triple: `delta` is drawn uniformly from
        [low, high), and the distortion is SKIPPED when a uniform(0, 1) draw is below
        `prob` (i.e. it is applied with probability 1 - prob).

        params:
        - hue       : hue rotation parameters
        - saturation: saturation blend-factor parameters
        - contrast  : multiplicative contrast-factor parameters
        - brightness: additive brightness-offset parameters
        """
        self.hue = hue
        self.saturation = saturation
        self.contrast = contrast
        self.brightness = brightness

    def __call__(self, data):
        """
        Apply the four distortions to data['image'] in a random order.

        params:
        - data: sample dict containing 'image'
        return:
        - data: the same dict with 'image' replaced by the distorted image
        """
        distortions = [self.random_hue, self.random_saturation,
                       self.random_contrast, self.random_brightness]
        distortions = np.random.permutation(distortions)  # shuffle the application order

        image = data['image']
        for f in distortions:
            image = f(image)
        data['image'] = image

        return data

    @staticmethod
    def _to_uint8(image):
        """
        Saturate a float image to [0, 255] and cast it to uint8.

        A bare astype('uint8') on out-of-range floats is undefined in NumPy (typically
        wraps around, e.g. 260 -> 4), which corrupts bright/dark pixels.
        """
        return np.clip(image, 0, 255).astype('uint8')

    def random_hue(self, image):
        """
        Randomly rotate the hue in YIQ space.

        params:
        - image: image to process (H x W x 3)
        return:
        - image: processed uint8 image
        """
        low, high, prob = self.hue
        if np.random.uniform(0, 1) < prob:
            return image

        image = image.astype('float32')
        delta = np.random.uniform(low, high)

        # NOTE(review): delta is multiplied by pi directly; if [low, high] is meant in
        # degrees this would need delta * pi / 180. Kept as-is to preserve behaviour.
        u = np.cos(delta * np.pi)
        w = np.sin(delta * np.pi)
        bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
        tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321], [0.211, -0.523, 0.311]])
        ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647], [1.0, -1.107, 1.705]])

        # RGB -> YIQ, rotate the chroma (I, Q) plane, YIQ -> RGB, applied per pixel.
        t = np.dot(np.dot(ityiq, bt), tyiq).T
        image = np.dot(image, t)

        return self._to_uint8(image)

    def random_saturation(self, image):
        """
        Randomly change saturation by blending the image with its grayscale version.

        params:
        - image: image to process (H x W x 3)
        return:
        - image: processed uint8 image
        """
        low, high, prob = self.saturation
        if np.random.uniform(0, 1) < prob:
            return image

        image = image.astype('float32')
        delta = np.random.uniform(low, high)

        gray = image * np.array([[[0.299, 0.587, 0.114]]], dtype='float32')
        gray = gray.sum(axis=2, keepdims=True)
        gray *= (1.0 - delta)

        image *= delta
        image += gray

        return self._to_uint8(image)

    def random_contrast(self, image):
        """
        Randomly change contrast by scaling pixel values with a factor from self.contrast.

        params:
        - image: image to process
        return:
        - image: processed uint8 image
        """
        # Fixed: this previously read self.saturation, ignoring the contrast setting.
        low, high, prob = self.contrast
        if np.random.uniform(0, 1) < prob:
            return image

        image = image.astype('float32')
        delta = np.random.uniform(low, high)

        image *= delta

        return self._to_uint8(image)

    def random_brightness(self, image):
        """
        Randomly change brightness by adding an offset from self.brightness.

        params:
        - image: image to process
        return:
        - image: processed uint8 image
        """
        # Fixed: this previously read self.saturation, ignoring the brightness setting.
        low, high, prob = self.brightness
        if np.random.uniform(0, 1) < prob:
            return image

        image = image.astype('float32')
        delta = np.random.uniform(low, high)

        image += delta

        return self._to_uint8(image)

class RandomExpand():
    def __init__(self,
                 ratio=4.0,
                 prob=0.5,
                 fill_value=(127.5, 127.5, 127.5)):
        """
        Configure random zoom-out (placing the image on a larger filled canvas).

        params:
        - ratio     : maximum canvas/image size ratio (must exceed 1.01)
        - prob      : probability of SKIPPING the expansion
        - fill_value: canvas colour as a tuple, one value per image channel
        """
        assert ratio > 1.01, '扩大比例必须大于1.01'
        assert isinstance(fill_value, tuple), '填充颜色必须为3元组'
        self.ratio = ratio
        self.prob = prob
        self.fill_value = fill_value

    def __call__(self, data):
        """
        Paste data['image'] at a random offset on an enlarged canvas and shift the boxes.

        params:
        - data: sample dict containing 'image' (H x W x C) and 'gtbox' (N x 4, x1 y1 x2 y2)
        return:
        - data: the same dict, expanded in place when the expansion is applied
        """
        skip = np.random.uniform(0, 1) < self.prob
        if skip:
            return data

        src = data['image']
        src_h, src_w, channels = src.shape
        factor = np.random.uniform(1, self.ratio)

        out_h = int(src_h * factor)
        out_w = int(src_w * factor)
        # The truncated canvas must be strictly larger in both dimensions.
        if out_h <= src_h or out_w <= src_w:
            return data
        off_y = np.random.randint(0, out_h - src_h)  # top-left y of the pasted image
        off_x = np.random.randint(0, out_w - src_w)  # top-left x of the pasted image

        canvas = np.empty((out_h, out_w, channels), dtype='float32')
        canvas[...] = np.array(self.fill_value, dtype='float32')
        canvas[off_y:off_y + src_h, off_x:off_x + src_w, :] = src

        data['image'] = canvas.astype('uint8')
        data['gtbox'] += np.array([off_x, off_y, off_x, off_y], dtype='float32')

        return data
    
class RandomCrop():
    def __init__(self,
                 aspect_ratio=(0.5, 2.0),
                 thresholds=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9),
                 scaling=(0.3, 1.0),
                 num_attempts=50,
                 allow_no_crop=True,
                 cover_all_box=False):
        """
        Configure IoU-constrained random cropping.

        params:
        - aspect_ratio : (min, max) range of the crop's width-scale / height-scale ratio,
                         or None to draw the two scales independently
        - thresholds   : candidate minimum IoU values between target boxes and the crop
        - scaling      : (min, max) range of the crop scale relative to the original image
        - num_attempts : maximum number of crop attempts per threshold
        - allow_no_crop: also allow returning the sample uncropped
        - cover_all_box: require every target box (not just the best one) to reach the threshold
        """
        self.aspect_ratio = aspect_ratio
        self.thresholds = thresholds
        self.scaling = scaling
        self.num_attempts = num_attempts
        self.allow_no_crop = allow_no_crop
        self.cover_all_box = cover_all_box

    def __call__(self, data):
        """
        Randomly crop the image, keeping only boxes whose centre lies inside the crop.

        params:
        - data: sample dict with 'image' (H x W x C), 'gtbox' (N x 4, x1 y1 x2 y2)
                and optionally 'gtcls' (N,)
        return:
        - data: the same dict; when a crop is applied, 'image', 'gtbox' and 'gtcls' are
                all updated consistently
        """
        if 'gtbox' in data and len(data['gtbox']) == 0:  # nothing to constrain the crop
            return data
        h, w = data['image'].shape[:2]
        gtbox = data['gtbox']

        # Shuffle the candidate thresholds, optionally including a "no crop" outcome.
        thresholds = list(self.thresholds)
        if self.allow_no_crop:
            thresholds.append('no_crop')
        np.random.shuffle(thresholds)

        for thresh in thresholds:
            if thresh == 'no_crop':
                return data

            found = False
            for _ in range(self.num_attempts):
                # Draw the crop scale (and aspect ratio).
                scale = np.random.uniform(*self.scaling)
                if self.aspect_ratio is not None:
                    min_ar, max_ar = self.aspect_ratio
                    # Bounds keep both h_scale and w_scale <= 1.
                    aspect_ratio = np.random.uniform(
                        max(min_ar, scale**2), min(max_ar, scale**-2))
                    h_scale = scale / np.sqrt(aspect_ratio)
                    w_scale = scale * np.sqrt(aspect_ratio)
                else:
                    h_scale = np.random.uniform(*self.scaling)
                    w_scale = np.random.uniform(*self.scaling)

                crop_h = h * h_scale
                crop_w = w * w_scale
                if self.aspect_ratio is None:  # keep crop height/width within [0.5, 2.0]
                    if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
                        continue

                # Random crop position. The "+ 1" makes the last valid offset reachable
                # and avoids randint(0, 0) raising when the crop spans the whole image.
                crop_h = int(crop_h)
                crop_w = int(crop_w)
                crop_y = np.random.randint(0, h - crop_h + 1)
                crop_x = np.random.randint(0, w - crop_w + 1)
                crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]  # x1 y1 x2 y2

                iou = self._iou_matrix(
                    gtbox, np.array([crop_box], dtype='float32'))  # (N, 1)
                if iou.max() < thresh:
                    continue
                if self.cover_all_box and iou.min() < thresh:
                    continue

                cpbox, valid_index = self._crop_box_with_center_constraint(
                    gtbox, np.array(crop_box, dtype='float32'))
                if valid_index.size > 0:
                    found = True
                    break

            if found:
                data['image'] = self._crop_image(data['image'], crop_box)
                data['gtbox'] = np.take(cpbox, valid_index, axis=0)
                # Keep classes aligned with the surviving boxes (previously left unfiltered,
                # which broke the class/box correspondence).
                if 'gtcls' in data:
                    data['gtcls'] = np.take(data['gtcls'], valid_index, axis=0)
                # NOTE(review): 'imghw' is not updated here and still holds the
                # pre-crop size — confirm downstream consumers expect that.
                return data

        return data

    def _iou_matrix(self, gtbox, crop_box):
        """
        Compute the pairwise IoU between two sets of x1 y1 x2 y2 boxes.

        params:
        - gtbox   : target boxes, shape (N, 4)
        - crop_box: crop boxes, shape (M, 4)
        return:
        - iou     : IoU matrix, shape (N, M)
        """
        # Intersection area (zero when the boxes do not overlap).
        tl_i = np.maximum(gtbox[:, np.newaxis, :2], crop_box[:, :2])
        br_i = np.minimum(gtbox[:, np.newaxis, 2:], crop_box[:, 2:])
        area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)

        # Union area.
        area_a = np.prod(gtbox[:, 2:] - gtbox[:, :2], axis=1)
        area_b = np.prod(crop_box[:, 2:] - crop_box[:, :2], axis=1)
        area_o = area_a[:, np.newaxis] + area_b - area_i

        return area_i / (area_o + 1e-10)

    def _crop_box_with_center_constraint(self, gtbox, crop_box):
        """
        Clip boxes to the crop and select those whose centre lies inside it.

        params:
        - gtbox      : target boxes, shape (N, 4)
        - crop_box   : the crop, shape (4,), x1 y1 x2 y2
        return:
        - cpbox      : boxes clipped to the crop, in crop coordinates, shape (N, 4)
        - valid_index: indices of boxes whose centre is inside the crop and whose
                       clipped box is non-empty
        """
        cpbox = gtbox.copy()

        cpbox[:, :2] = np.maximum(gtbox[:, :2], crop_box[:2])
        cpbox[:, 2:] = np.minimum(gtbox[:, 2:], crop_box[2:])
        cpbox[:, :2] -= crop_box[:2]
        cpbox[:, 2:] -= crop_box[:2]

        centers = (gtbox[:, :2] + gtbox[:, 2:]) / 2
        valid = np.logical_and(crop_box[:2] <= centers, centers < crop_box[2:]).all(axis=1)
        valid = np.logical_and(valid, (cpbox[:, :2] < cpbox[:, 2:]).all(axis=1))
        valid_index = np.where(valid)[0]

        return cpbox, valid_index

    def _crop_image(self, image, crop_box):
        """
        Slice the crop out of an H x W x C image.

        params:
        - image   : original image
        - crop_box: integer x1 y1 x2 y2
        return:
        - image   : cropped view
        """
        x1, y1, x2, y2 = crop_box
        return image[y1:y2, x1:x2, :]
        
class RandomFlip():
    def __init__(self, prob=0.5):
        """
        Configure random horizontal flipping.

        params:
        - prob: probability of SKIPPING the flip
        """
        self.prob = prob

    def __call__(self, data):
        """
        Mirror data['image'] left-right and update the box x coordinates.

        params:
        - data: sample dict with 'image' (H x W x C) and 'gtbox' (N x 4, x1 y1 x2 y2)
        return:
        - data: the same dict; 'image' becomes a flipped view and 'gtbox' is updated in place
        """
        if np.random.uniform(0, 1) < self.prob:
            return data

        mirrored = data['image'][:, ::-1, :]
        data['image'] = mirrored
        width = mirrored.shape[1]

        boxes = data['gtbox']
        # Fancy indexing on the right-hand side copies, so new x1 = W - old x2 and
        # new x2 = W - old x1 are computed from the original values.
        boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
        data['gtbox'] = boxes

        return data

####################################################################################

# PID of the process that imported this module; BatchCompose.__call__ compares
# os.getpid() against it to decide whether it is running in the main process.
# NOTE(review): this relies on worker processes inheriting the value (fork). Under a
# 'spawn' start method each worker re-imports the module and records its own PID
# here — confirm the intended start method.
MAIN_PID = os.getpid() # main process id

class BatchCompose():
    def __init__(self, batch_transforms):
        """
        Initialize the batch-level transform pipeline (used as the DataLoader collate_fn).

        params:
        - batch_transforms: list of callables; each takes and returns a list of sample dicts
        """
        self.batch_transforms = batch_transforms
        self.lock = mp.Lock()  # guards the one-time recording of output_fields across processes
        # Manager-backed list, so key names recorded inside a loader worker process are
        # visible to the process that reads them back (DataLoader.__next__).
        self.output_fields = mp.Manager().list([])

    def __call__(self, batch_data):
        """
        Apply the batch transforms and stack each field across the batch.

        params:
        - batch_data: list of sample dicts
        return:
        - batch_data: list of ndarrays, one per key of the first sample, in the order
                      recorded in self.output_fields; each is stacked along a new axis 0
        """
        for f in self.batch_transforms:
            batch_data = f(batch_data)

        # When called in the main process a plain list suffices; drop the shared proxy.
        if os.getpid() == MAIN_PID and isinstance(self.output_fields, mp.managers.ListProxy):
            self.output_fields = []

        # Record the output key order once (double-checked under the lock). The `with`
        # block releases the lock even if recording raises, so other workers cannot deadlock.
        if len(self.output_fields) == 0:
            with self.lock:
                if len(self.output_fields) == 0:
                    self.output_fields.extend(list(batch_data[0].keys()))

        # Snapshot once: slicing a ListProxy is a single round-trip, whereas iterating it
        # costs one round-trip per element.
        fields = self.output_fields[:]

        return [np.stack([sample[k] for sample in batch_data], axis=0) for k in fields]
    
class RandomResize():
    def __init__(self, dsize, random_dsize=False, random_interpolation=False):
        """
        Configure batch-level resizing to a square output.

        params:
        - dsize               : list of candidate output side lengths (pixels)
        - random_dsize        : pick a random size per call instead of dsize[0]
        - random_interpolation: pick a random interpolation per call instead of the first one
        """
        self.dsize = dsize
        # Candidate OpenCV interpolation flags; index 0 is used when not random.
        self.interpolation = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA,
                              cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
        self.random_dsize = random_dsize
        self.random_interpolation = random_interpolation

    def __call__(self, batch_data):
        """
        Resize every image in the batch to the same square size and rescale its boxes.

        params:
        - batch_data: list of sample dicts with 'image' and 'gtbox' (N x 4, x1 y1 x2 y2)
        return:
        - batch_data: the same list, updated in place
        """
        # One size and one interpolation are chosen per call, shared by the whole batch.
        dsize = np.random.choice(self.dsize) if self.random_dsize else self.dsize[0]
        interpolation = (np.random.choice(self.interpolation)
                         if self.random_interpolation else self.interpolation[0])

        for sample in batch_data:
            src_h, src_w = sample['image'].shape[:2]
            x_factor = dsize / src_w
            y_factor = dsize / src_h
            sample['image'] = self.resize_image(sample['image'], dsize, interpolation)
            sample['gtbox'] = self.resize_gtbox(sample['gtbox'], dsize, [x_factor, y_factor])

        return batch_data

    def resize_image(self, image, dsize, interpolation):
        """
        Resize an image to dsize x dsize.

        params:
        - image        : original image
        - dsize        : output side length
        - interpolation: OpenCV interpolation flag
        return:
        - image        : resized image
        """
        return cv2.resize(image, (dsize, dsize), interpolation=interpolation)

    def resize_gtbox(self, gtbox, dsize, scale):
        """
        Scale x1 y1 x2 y2 boxes in place and clip them to [0, dsize].

        params:
        - gtbox: boxes, shape (N, 4)
        - dsize: output side length used as the clip bound
        - scale: [x factor, y factor]
        return:
        - gtbox: the same array, rescaled and clipped
        """
        x_factor, y_factor = scale[0], scale[1]
        gtbox[:, 0::2] *= x_factor
        gtbox[:, 1::2] *= y_factor
        # x and y share the same [0, dsize] bound, so clip all columns at once.
        gtbox[:, :] = np.clip(gtbox, 0, dsize)
        return gtbox
    
class PadClassBox():
    def __init__(self, num_max):
        """
        Configure the fixed number of target slots per sample.

        params:
        - num_max: number of slots; extra targets are dropped, missing ones zero-filled
        """
        self.num_max = num_max

    def __call__(self, batch_data):
        """
        Pad or truncate each sample's 'gtcls' and 'gtbox' to exactly num_max rows.

        params:
        - batch_data: list of sample dicts with 'gtcls' (N,) and 'gtbox' (N, 4)
        return:
        - batch_data: the same list; 'gtcls' becomes int32 (num_max,) and
                      'gtbox' becomes float32 (num_max, 4)
        """
        layouts = (('gtcls', (), 'int32'), ('gtbox', (4,), 'float32'))
        for sample in batch_data:
            for key, tail, dtype in layouts:
                src = sample[key]
                count = min(self.num_max, len(src))
                padded = np.zeros((self.num_max,) + tail, dtype=dtype)
                if count > 0:
                    padded[:count] = src[:count]
                sample[key] = padded

        return batch_data

class BoxNormalize():
    def __call__(self, batch_data):
        """
        Divide box x coordinates by the image width and y coordinates by its height.

        params:
        - batch_data: list of sample dicts with 'image' (H x W x ...) and 'gtbox' (N, 4)
        return:
        - batch_data: the same list; each 'gtbox' array is modified in place
        """
        for sample in batch_data:
            boxes = sample['gtbox']
            img_h, img_w = sample['image'].shape[:2]

            boxes[:, 0::2] = boxes[:, 0::2] / img_w  # x1, x2
            boxes[:, 1::2] = boxes[:, 1::2] / img_h  # y1, y2

            sample['gtbox'] = boxes

        return batch_data
    
class BoxXYXY2XYWH():
    def __call__(self, batch_data):
        """
        Convert boxes from corner form (x1 y1 x2 y2) to centre form (cx cy w h).

        params:
        - batch_data: list of sample dicts with 'gtbox' (N, 4)
        return:
        - batch_data: the same list; each 'gtbox' array is modified in place
        """
        for sample in batch_data:
            boxes = sample['gtbox']

            sizes = boxes[:, 2:4] - boxes[:, :2]
            boxes[:, 2:4] = sizes
            boxes[:, :2] = boxes[:, :2] + boxes[:, 2:4] / 2.0  # top-left + half size = centre

            sample['gtbox'] = boxes

        return batch_data
    
class ImageNormalize():
    def __init__(self, mean=[0.485, 0.456, 0.406], stdv=[0.229, 0.224, 0.225]):
        """
        Configure per-channel standardisation.

        params:
        - mean: per-channel mean, on the [0, 1] pixel scale
        - stdv: per-channel standard deviation, on the [0, 1] pixel scale
        """
        self.mean = mean
        self.stdv = stdv

    def __call__(self, batch_data):
        """
        Standardise each image: scale pixels to [0, 1], subtract mean, divide by stdv.

        The result is zero-centred and generally NOT within [0, 1].

        params:
        - batch_data: list of sample dicts with 'image' (H x W x C)
        return:
        - batch_data: the same list with float32 'image' arrays
        """
        # Broadcastable (1, 1, C) constants, built once per batch instead of per image.
        channel_mean = np.array(self.mean, dtype='float32')[np.newaxis, np.newaxis, :]
        channel_stdv = np.array(self.stdv, dtype='float32')[np.newaxis, np.newaxis, :]

        for sample in batch_data:
            pixels = sample['image'].astype('float32', copy=False)
            sample['image'] = ((pixels / 255.0) - channel_mean) / channel_stdv

        return batch_data

class ImagePermute():
    def __call__(self, batch_data):
        """
        Reorder each image's axes from H x W x C to C x H x W.

        params:
        - batch_data: list of sample dicts with a 3-D 'image'
        return:
        - batch_data: the same list with transposed 'image' views
        """
        for sample in batch_data:
            sample['image'] = np.transpose(sample['image'], (2, 0, 1))

        return batch_data

####################################################################################    

class DataLoader():
    def __init__(self, lists_txt, label_txt, batch_size=1, worker_num=0, mode='train'):
        """
        Build the dataset, sampler, transform pipelines and the underlying paddle loader.

        params:
        - lists_txt : path of the list file
        - label_txt : path of the label file
        - batch_size: samples per batch
        - worker_num: passed to paddle.io.DataLoader as num_workers
        - mode      : 'train' (augmentation, shuffling, random multi-scale resize) or
                      'valid' (no augmentation, fixed 608 x 608 resize)
        """
        assert batch_size > 0, '错误：批次数据大小必须大于0!'
        assert worker_num >= 0, '错误：读取子线程数大于等于0!'
        assert mode in ['train', 'valid'], '错误：数据读取模式必须为"train"或"valid"!'

        is_train = mode == 'train'

        # Per-sample augmentation (training only).
        if is_train:
            sample_ops = [RandomDistort(),
                          RandomExpand(fill_value=(123.675, 116.28, 103.53)),
                          RandomCrop(),
                          RandomFlip()]
        else:
            sample_ops = []
        self.transforms = Compose(sample_ops)

        self.dataset = VOCDataset(lists_txt, label_txt, self.transforms)

        # Batch sampler: shuffled only while training; the last partial batch is kept.
        self.batch_sampler = paddle.io.DistributedBatchSampler(dataset=self.dataset,
                                                               batch_size=batch_size,
                                                               shuffle=is_train,
                                                               drop_last=False)

        # Batch-level pipeline; only the resize step differs between modes.
        if is_train:
            resize = RandomResize(dsize=[320, 352, 384, 416, 448,
                                         480, 512, 544, 576, 608],
                                  random_dsize=True,
                                  random_interpolation=True)
        else:
            resize = RandomResize(dsize=[608],
                                  random_dsize=False,
                                  random_interpolation=False)
        self.batch_transforms = BatchCompose([resize,
                                              PadClassBox(num_max=50),
                                              BoxNormalize(),
                                              BoxXYXY2XYWH(),
                                              ImageNormalize(mean=[0.485, 0.456, 0.406],
                                                             stdv=[0.229, 0.224, 0.225]),
                                              ImagePermute()])

        self.batch_loader = paddle.io.DataLoader(dataset=self.dataset,
                                                 batch_sampler=self.batch_sampler,
                                                 collate_fn=self.batch_transforms,
                                                 num_workers=worker_num,
                                                 return_list=False,
                                                 use_buffer_reader=True,
                                                 use_shared_memory=False)
        self.iter_loader = iter(self.batch_loader)

    def __iter__(self):
        """
        Return self; iteration is driven by __next__.
        """
        return self

    def __next__(self):
        """
        Return the next batch as a dict keyed by the field names recorded in BatchCompose.

        On exhaustion the underlying iterator is recreated before StopIteration is
        raised, so the next pass over this object starts a fresh epoch.
        """
        try:
            batch_values = next(self.iter_loader)
        except StopIteration:
            self.iter_loader = iter(self.batch_loader)
            raise StopIteration
        return dict(zip(self.batch_transforms.output_fields, batch_values))

    def __len__(self):
        """
        Return the number of batches per epoch.
        """
        return len(self.batch_sampler)